import torch
import time
import src.reglm.lightning
import argparse
import time
from tqdm import tqdm


# Command-line interface. The parsed namespace is bound to the module-level name
# ``args``, which the ``__main__`` guard passes to main().
parser = argparse.ArgumentParser(
    description="Refine sampled sequences with absorb/escape and write them to FASTA."
)

# --- checkpoint loading ---
# Known HyenaDNA model names:
#   hyenadna-tiny-16k-seqlen-d128, hyenadna-large-1m-seqlen,
#   hyenadna-medium-160k-seqlen, hyenadna-medium-450k-seqlen,
#   hyenadna-small-32k-seqlen, hyenadna-tiny-1k-seqlen, hyenadna-tiny-1k-seqlen-d256
parser.add_argument("--model", type=str, default="hyenadna-tiny-16k-seqlen-d128",
                    help="backbone model name passed to load_from_checkpoint")
# Required: the script cannot do anything useful without these, and a missing
# --label used to fail only after every sequence had been refined.
parser.add_argument("--ckpt", type=str, required=True,
                    help="path to the Lightning checkpoint")
parser.add_argument("--original_samples", type=str, required=True,
                    help="file loaded with torch.load holding the original samples' logits")
# choices mirror the modes accepted by absorb()/escape(), so typos fail at parse time.
parser.add_argument("--absorb_mode", type=str, default='threshold',
                    choices=['threshold', 'random'])
parser.add_argument("--escape_mode", type=str, default='natural',
                    choices=['natural', 'threshold', 'random', 'max_length'])
parser.add_argument("--start", type=int, help="first sample index to refine (inclusive)")
parser.add_argument("--end", type=int, help="sample index to stop at (exclusive)")
parser.add_argument("--label", type=str, required=True,
                    help="two-digit species code, e.g. '07' (human)")
parser.add_argument("--absorb_threshold", type=float, default=0.8,
                    help="max-probability cutoff below which a position is absorbed")

args = parser.parse_args()



def write_to_fasta(sequences, output_file):
    """
    Append DNA sequences to a FASTA file, one record per sequence.

    Each call numbers its records from 1 (``>sequence_1``, ``>sequence_2``, ...),
    so appending several batches to the same file repeats header names.

    :param sequences: An iterable of DNA sequences (each is formatted with an f-string).
    :param output_file: Path of the FASTA file; it is opened in append mode.
    """
    with open(output_file, 'a') as handle:
        handle.writelines(
            f">sequence_{number}\n{record}\n"
            for number, record in enumerate(sequences, start=1)
        )

def tensor_to_dna(tensor, inference=False):
    """Decode a per-position score tensor into a nucleotide string.

    ``tensor`` is indexed as (position, class): the class is picked along dim 1,
    and class indices 0..4 map to 'A', 'T', 'G', 'C', 'N'. Any other index
    raises KeyError.

    When ``inference`` is True, positions whose top-scoring class is 'N'
    (index 4) fall back to the second-best class, so no 'N' is emitted.
    This requires at least two classes per position.
    """
    index_to_base = dict(enumerate('ATGCN'))

    if not inference:
        picked = torch.argmax(tensor, dim=1)
    else:
        ranked = torch.topk(tensor, 2, dim=1).indices
        best, runner_up = ranked[:, 0], ranked[:, 1]
        # Replace 'N' winners with the runner-up class.
        picked = torch.where(best == 4, runner_up, best)

    return ''.join(index_to_base[code] for code in picked.cpu().tolist())

def escape(seq, max_logists, model, escape_type='natural', threshold=0.3,
           escape_prob=0.1, max_sub_length=100, label=None, rng=None):
    """
    Regenerate tokens after a prefix with the model until an escape condition fires.

    Starting from the prefix ``seq`` (conditioned on ``label``), each step asks
    ``model.absorb_escape_step`` for the max next-token probability and a sampled
    token. At step j the escape condition is checked *before* the token is
    appended. On escape the loop stops, keeping the j tokens generated so far.

    Escape conditions at step j:
    1. (natural)    model max next-token probability > max_logists[j]
    2. (threshold)  model max next-token probability < threshold
    3. (random)     torch.rand(1) (global torch RNG) < escape_prob
    4. (max_length) j > max_sub_length
    ----------------------
    args:
        seq: prefix (positions 0..i-1) in whatever form model.encode_seqs accepts;
            absorb() passes a string slice.
        max_logists: the original sample's per-position max probabilities from
            position i onward. Its first dimension sets the maximum number of steps,
            and each max_logists[j] must compare to a scalar.
        model: provides encode_labels, encode_seqs, absorb_escape_step, decode,
            device and label_len.
        escape_type: 'natural', 'threshold', 'random', or 'max_length'.
        threshold: probability cutoff for 'threshold' escape.
        escape_prob: per-step escape probability for 'random' escape.
        max_sub_length: step cutoff for 'max_length' escape.
        label: one-element list of labels (a batch of one); defaults to ['0'].
        rng: torch.Generator handed to model.absorb_escape_step. It is reseeded from
            the wall clock on every call. If None, a new generator on
            model.device is created.
    returns:
        (length, new_seq, full_seq_flag)
        - length: the step index j at which escape fired, which equals the number
          of tokens appended. If no escape fired it is the last step index
          (0 when max_logists is empty).
        - new_seq: the decoded prefix plus the generated tokens.
        - full_seq_flag: True when no escape fired, i.e. the whole remainder
          was regenerated.
    """
    assert escape_type in ['natural', 'threshold', 'random', 'max_length'], "escape_type must be one of 'natural', 'threshold', 'random', or 'max_length'"
    if label is None:
        # Avoid a shared mutable default; same effective default as before.
        label = ['0']
    if rng is None:
        rng = torch.Generator(device=model.device)
    # Seed from the nanosecond clock. Whole-second seeds (int(time.time())) gave
    # every escape call within the same second an identical random stream.
    rng.manual_seed(time.time_ns())

    length = 0
    full_seq_flag = True
    # Encode label (with start token) and prefix as a batch of one, then join them.
    label_idx = model.encode_labels(label, add_start=True).to(model.device)
    idxs = model.encode_seqs([seq], add_stop=False).to(model.device)  # TODO: use the default
    idxs = torch.cat((label_idx, idxs), dim=1)

    for j in range(max_logists.shape[0]):
        # Expected to return (max next-token probability, sampled token).
        max_next_token_prob, next_token = model.absorb_escape_step(idxs, rng)
        length = j
        if escape_type == 'natural':
            # The model is more confident here than the original sample was.
            escaped = max_next_token_prob > max_logists[j]
        elif escape_type == 'threshold':
            escaped = max_next_token_prob < threshold
        elif escape_type == 'random':
            escaped = torch.rand(1) < escape_prob
        else:  # 'max_length'
            escaped = j > max_sub_length
        if escaped:
            full_seq_flag = False
            break
        # Presumably next_token is 0-d, so this appends a (1, 1) column.
        idxs = torch.cat((idxs, next_token.unsqueeze(0).unsqueeze(0)), dim=1)

    # Drop the label_len label tokens plus one leading token (presumably the start
    # token from add_start=True) before decoding.
    decoded = model.decode(idxs[:, model.label_len + 1:])
    return length, decoded[0], full_seq_flag



def absorb(seq, max_logits, model, absorb_type="threshold", escape_type="natural", threshold=0.9, absorb_prob=0.1, label=0, rng=None):
    """
    Scan a sequence and hand low-confidence stretches to ``escape`` for regeneration.

    Positions are walked left to right. An absorption is triggered at position i when:
    1. (threshold) max_logits[i] < threshold, or
    2. (random)    torch.rand(1) (global torch RNG) < absorb_prob.

    On absorption, ``escape`` regenerates tokens from i. Its output replaces
    seq[:i + length], and scanning resumes one position past the regenerated
    stretch. If escape regenerated the whole remainder, scanning stops.
    ----------------------
    args:
        seq: sequence to refine. It is sliced and concatenated with escape's decoded
            output; main() passes a DNA string.
        max_logits: the original sample's per-position max probabilities, indexed like seq.
        model: forwarded to escape.
        absorb_type: 'threshold' or 'random'.
        escape_type: forwarded to escape.
        threshold: absorption cutoff for 'threshold' mode.
        absorb_prob: per-position absorption probability for 'random' mode.
        label: conditioning label, forwarded to escape as [label] in both modes.
        rng: torch.Generator forwarded to escape.
    returns:
        The refined sequence.
    """
    assert absorb_type in ['threshold', 'random'], "absorb_type must be one of 'threshold' or 'random'"
    seq_len = len(seq)
    i = 0
    full_seq_flag = False
    while i < seq_len:
        if absorb_type == 'threshold':
            should_absorb = max_logits[i] < threshold
        else:  # 'random'
            should_absorb = torch.rand(1) < absorb_prob
        if should_absorb:
            # Pass the label in both modes. The random branch previously omitted it,
            # silently conditioning on escape's default label instead.
            length, sub_seq, full_seq_flag = escape(
                seq[:i], max_logits[i:], model,
                escape_type=escape_type, rng=rng, label=[label],
            )
            seq = sub_seq if full_seq_flag else sub_seq + seq[i + length:]
            i += length
        i += 1
        if full_seq_flag:
            break
    return seq


def main(args):
    """
    Refine a slice of sampled sequences with absorb/escape and append them to FASTA.

    Steps:
    1. Load the Lightning checkpoint ``args.ckpt`` onto CUDA device 0.
    2. Load the original samples' logits from ``args.original_samples``.
    3. Keep samples ``[args.start:args.end]``.
    4. Refine each with ``absorb``, conditioned on ``args.label``.
    5. Append the results with ``write_to_fasta``.

    :param args: parsed command-line namespace (see the argparse setup at the top of the file).
    :raises ValueError: if ``args.label`` is not one of the species codes below. This is
        checked before any model loading, so a bad label fails fast.
    """
    # Species code -> name; used only in the output filename.
    reverse_mapping = {
        '00': 'Honey_Bee',
        '01': 'thale_cress',
        '02': 'elegans',
        '03': 'dog',
        '04': 'zebrafish',
        '05': 'fruit_fly',
        '06': 'chicken',
        '07': 'human',
        '08': 'macaque',
        '09': 'mouse',
        '10': 'plasmodium',
        '11': 'rat',
        '12': 'baking_yeast',
        '13': 'fission_yeast',
        '14': 'corn',
    }
    # Resolve the label up front. The lookup used to happen only after every
    # sequence had been refined, so a bad label wasted the entire run.
    if args.label not in reverse_mapping:
        raise ValueError(
            f"unknown label {args.label!r}; expected one of {sorted(reverse_mapping)}"
        )
    species = reverse_mapping[args.label]

    model = src.reglm.lightning.LightningModel.load_from_checkpoint(args.ckpt, model=args.model)
    model = model.to(torch.device(0))
    absorb_type = args.absorb_mode
    escape_type = args.escape_mode
    rng = torch.Generator(device=model.device)

    # NOTE(review): torch.load unpickles arbitrary Python objects; only point
    # --original_samples at trusted files.
    logits_list = torch.load(args.original_samples)
    # Slice first so DNA conversion and softmax run only on the requested samples.
    selected = logits_list[args.start:args.end]
    seq = [tensor_to_dna(logits) for logits in selected]
    # Per-position max of the softmax-normalized logits (the original sample's confidence).
    max_logits = [torch.softmax(logits, dim=-1).max(dim=-1).values for logits in selected]

    print(args.label)
    final_refined_seq = []
    for i, (dna, confidence) in enumerate(tqdm(zip(seq, max_logits), total=len(seq))):
        print(f"Refining {i} th sequence")
        refined_seq = absorb(
            dna, confidence, model,
            absorb_type=absorb_type, escape_type=escape_type,
            threshold=args.absorb_threshold, absorb_prob=0.1,
            label=args.label, rng=rng,
        )
        final_refined_seq.append(refined_seq)

    # NOTE(review): the "Yeast_" filename prefix is fixed regardless of the species label.
    write_to_fasta(final_refined_seq, f"Yeast_Absorb_threshold_{args.absorb_threshold}_{species}_Escape_{escape_type}_Absorb_{absorb_type}_{args.model}_num{len(seq)}_start_{args.start}_end_{args.end}.fasta")

if __name__ == "__main__":
    # Run the entry point with the namespace parsed at import time.
    main(args=args)
